library(distributions3)
library(MASS)
library(LaplacesDemon)
library(matrixStats)
library(expm)
library(ggplot2)
library(Hmisc)

rm(list = ls()) # start from an empty workspace

###################### REPLICATION INSTRUCTION ################################
# Set Figure1 below to one of 'a','b',...,'n' to reproduce the corresponding
# subplot (a),...,(n) of Figure 1 in our paper. The histograms (j)---(n)
# are simulated three times; the first plot is the one presented in the paper.
Figure1 <- "l"
###############################################################################


#### DEFAULT SETTINGS (each subplot branch may override them) ####
set.seed(1) # fixed seed for the random draws below

# treatment arm, normal-mean model
mu01 <- 0 # prior mean of the treatment outcome
var01 <- 100 # prior variance of the treatment outcome
varmd1 <- 10 # outcome variance assumed by the model (treatment)
varmd1_true <- 10 # outcome variance used to simulate data (treatment)

# control arm, normal-mean model
mu00 <- 0 # prior mean of the control outcome
var00 <- 100 # prior variance of the control outcome
varmd0 <- 10 # outcome variance assumed by the model (control)
varmd0_true <- 10 # outcome variance used to simulate data (control)

# NOTE: T is the number of stages here; it masks R's built-in alias for TRUE
# in this workspace, so logical constants must be spelled TRUE/FALSE.
T <- 10
is_dynamic <- 0 # 1: budget/tolerance follow a per-stage schedule, 0: static

Narr <- rep(500, T) # number of users available at each stage

leg_pos <- 'bottomright' # legend position
emp_bayes <- 0 # 1: estimate the outcome variances from the data

I <- 1000 # number of Monte-Carlo trials (for testing--not part of the method)
is_print <- 0 # >0 prints per-stage diagnostics

# 'ours' is our model, 'thompson' is the Thompson Bayesian bandit
bandit <- 'ours'

plttype <- 1 # 1: ramp schedule, 2: budget surplus, 3: budget distribution
plt_num <- 3 # number of configurations (curves / histograms) to run
line_color <- c('black', 'red', 'blue') # colours of the line plots
Tplt <- 10 # number of stages shown on the budget-surplus plot
ymin <- 0 # lower y-axis limit
ymax <- 250 # upper y-axis limit

#### SET PLOTTING AND HYPERPARAMETERS FOR EACH SUBPLOT OF FIGURE 1 #####
# Settings shared by a family of subplots are assigned once per family; the
# if/else chain at the end holds only what is specific to a single subplot.

# stop early on a label that none of the branches below handles
if (length(Figure1) != 1 || !(Figure1 %in% letters[1:14]))
{
  stop("Figure1 must be a single letter between 'a' and 'n'", call. = FALSE)
}

# (a),(b),(c),(g),(h): our method with six (budget, tolerance) schedules
if (Figure1 %in% c('a', 'b', 'c', 'g', 'h'))
{
  T <- 10
  add_plt <- c(0,1,2,3,4,5)
  plt_num <- 6
  is_dynamic <- 1
  Budget <- c(-100, -100, -500, -500, -500, -500) # total budget for each trial
  # per-stage budget schedule, one row of s_plt per configuration
  s1 <- array(data=rep(Budget[1], times=T))
  s2 <- array(data=rep(Budget[2], times=T))
  s3 <- array(data=rep(Budget[3], times=T))
  s4 <- array(data=rep(Budget[4], times=T))
  s5 <- array(data=rep(Budget[5], times=T))
  # 'ration budget': 100 above the total budget during the first five stages
  s6 <- c(array(data=rep(Budget[6]+100, times=5)), array(data=rep(Budget[6], times=T-5)))
  s_plt <- rbind(s1,s2,s3,s4,s5,s6)
  # per-stage tolerance schedule, one row of deltao_plt per configuration;
  # delta1/delta2 are the per-stage values d with 1-(1-d)^T = 0.01 / 0.2
  delta1 <- array(data=rep(1-(1-0.01)^(1/T), times=T))
  delta2 <- array(data=rep(1-(1-0.2)^(1/T), times=T))
  delta3 <- delta1
  delta4 <- delta2
  # 'ration tolerance': a smaller tolerance during the first five stages
  delta5 <- c(array(data=rep(0.0001, times=5)), array(data=rep(0.001908, times=5)))
  delta6 <- delta1
  deltao_plt <- rbind(delta1,delta2,delta3,delta4,delta5,delta6)
  line_color <- c('black', 'red', 'blue','green','orange', 'brown')
  legendtxt <- c('(-100, 0.01)', '(-100, 0.2)', '(-500, 0.01)', '(-500, 0.2)','ration tolerance', 'ration budget')
}

# (d),(f): LinkedIn experiment, six stages, empirical-Bayes variance
if (Figure1 %in% c('d', 'f'))
{
  T <- 6
  Narr <- c(10756, 10460, 10598, 7580, 10550, 10688)
  dg_mode <- 6
  # per-stage means and variances used to simulate the outcomes
  mu_true1_sq <- c(0.365879725, 0.378782722, 0.375408491, 0.231741622, 0.401021283, 0.394088418)
  mu_true0_sq <- c(0.364815296, 0.378043863, 0.375240757, 0.231723922, 0.400918187, 0.393006496)
  varmd1_sq <- c(2.092336881, 2.224840653, 2.013513425, 1.05259011, 2.247600999, 2.443001975)
  varmd0_sq <- c(2.099301341, 2.276904621, 2.090900932, 1.11645222, 2.270533675, 2.398172473)
  is_dynamic <- 1
  plttype <- 1
  I <- 500
  leg_pos <- 'topleft'
  ymin <- 0
  ymax <- 8500
  emp_bayes <- 1
}

# (e),(f),(i): Thompson bandit (TOM) prior and colours
if (Figure1 %in% c('e', 'f', 'i'))
{
  bandit <- 'thompson'
  mu00 <- 0
  mu01 <- -2
  var00 <- 0.05
  var01 <- 0.05
  line_color <- c('black', 'red', 'blue','green','orange', 'brown', 'grey')
}

# (e),(i): Thompson bandit on the improving-treatment (NPTE) scenario
if (Figure1 %in% c('e', 'i'))
{
  T <- 10
  add_plt <- c(0,1,2,3,4,5,6)
  legendtxt <- c('c=0.001', 'c=0.007', 'c=0.01','c=0.02','c=0.05','c=0.1') # legend text
  dg_mode <- 2
  mu_true1_sq <- pmin(seq(from = -2, by = 0.5, length.out = T), 2)
  c_arr <- c(0.001,0.007,0.01,0.02,0.05,0.1) # one exponent c per configuration
  I <- 500
  plt_num <- 6
  Budget <- rep(-500, plt_num)
}

# (j)-(n): budget histograms; the same simulation is run three times and
# three plots are produced
if (Figure1 %in% c('j', 'k', 'l', 'm', 'n'))
{
  T <- 10
  I <- 5000
  plttype <- 3
  Budget <- rep(-500, T)
  s_plt <- c(-500, -500, -500)
  deltao_plt <- c(0.05, 0.05, 0.05)
  add_plt <- c(-1,-1,-1)
}

# settings specific to a single subplot
if (Figure1=='a') # Experiment(Pos_ramp): PTE, ramp schedule
{
  dg_mode <- 1 # treatment mean is 1, control mean is 0
  mu_true1 <- 1
  mu_true0 <- 0
  plttype <- 1
  I <- 500
  leg_pos <- 'bottomright'
} else if (Figure1=='b') # Experiment(Neg_ramp): NTE, ramp
{
  dg_mode <- 1 # treatment mean is 0, control mean is 1
  mu_true1 <- 0
  mu_true0 <- 1
  plttype <- 1
  I <- 500
  leg_pos <- 'topright'
} else if (Figure1=='c') # Experiment (Imp_ramp): NPTE, ramp schedule
{
  dg_mode <- 2
  mu_true1_sq <- pmin(seq(from = -2, by = 0.5, length.out = T), 2) # true mean of new feature is improving
  plttype <- 1
  Tplt <- 10
  I <- 50
  leg_pos <- 'topleft'
} else if (Figure1=='d') # Experiment (Real_ramp): LinkedIn, ramp schedule
{
  add_plt <- c(0,1,2,3,4)
  plt_num <- 5
  Budget <- c(-400, -400, -1500, -1500, -1500) # total budget for each trial
  # per-stage budget schedule, one row of s_plt per configuration
  s1 <- array(data=rep(Budget[1], times=T))
  s2 <- array(data=rep(Budget[2], times=T))
  s3 <- array(data=rep(Budget[3], times=T))
  s4 <- array(data=rep(Budget[4], times=T))
  s5 <- c(-400, -400, -400, -400, -1500, -1500)
  s_plt <- rbind(s1,s2,s3,s4,s5)
  # per-stage tolerance schedule, one row of deltao_plt per configuration
  delta1 <- array(data=rep(1-(1-0.01)^(1/T), times=T))
  delta2 <- array(data=rep(1-(1-0.2)^(1/T), times=T))
  delta3 <- delta1
  delta4 <- delta2
  delta5 <- delta1
  deltao_plt <- rbind(delta1,delta2,delta3,delta4,delta5)
  line_color <- c('black', 'red', 'blue','green','orange', 'grey')
} else if (Figure1=='e') # Experiment (Thomp_ramp): NPTE, TOM
{
  plttype <- 1
} else if (Figure1=='f') # Experiment (Thomp_real_ramp): LinkedIn, TOM
{
  c_arr <- c(0.01,0.05,0.1,0.15,0.17,0.2) # one exponent c per configuration
  add_plt <- c(0,1,2,3,4,5)
  plt_num <- 6
  Budget <- c(-400, -400, -1500, -1500, -1500,-1500) # total budget for each trial
} else if (Figure1=='g') # Experiment(Neg_budget_trend): NTE, budget surplus
{
  dg_mode <- 1 # treatment mean is 0, control mean is 1
  mu_true1 <- 0
  mu_true0 <- 1
  plttype <- 2
  Tplt <- 10
  ymin <- 0
  ymax <- 500
  I <- 500
  leg_pos <- 'topright'
} else if (Figure1=='h') # Experiment (Imp_budget_trend): NPTE, budget surplus
{
  dg_mode <- 2
  mu_true1_sq <- pmin(seq(from = -2, by = 0.5, length.out = T), 2) # true mean of new feature is improving
  plttype <- 2
  Tplt <- 10
  ymin <- 0
  ymax <- 800
  I <- 500
  leg_pos <- 'topleft'
} else if (Figure1=='i') # Experiment (Thomp_budget_trend): NPTE, TOM, budget surplus
{
  plttype <- 2
  ymin <- -1000
  ymax <- 1250
} else if (Figure1=='j') # Experiment (BudgetD_con): hist, std Gaussian
{
  dg_mode <- 1
  mu_true0 <- 1
  mu_true1 <- 0
} else if (Figure1=='k') # Experiment (BudgetD_corr): hist, mv normal
{
  dg_mode <- 4
  mu_true <- c(1,0)
  Cov_true <- matrix(c(10,8,8,10), nrow = 2, byrow = TRUE)
  # The estimation-accuracy summary reads mu_true0/mu_true1, which this
  # subplot otherwise never defines. The simulation draws control outcomes
  # from column 1 and treatment outcomes from column 2 of the bivariate normal.
  mu_true0 <- mu_true[1]
  mu_true1 <- mu_true[2]
} else if (Figure1=='l') # Experiment (BudgetD_bern): hist, bernoulli
{
  dg_mode <- 3
  mu_true0 <- 0.5786
  mu_true1 <- 0.4224
  # binary outcomes are multiplied by mag0/mag1 when simulated
  mag1 <- 1/(mu_true0-mu_true1)
  mag0 <- mag1
  # variance of a scaled Bernoulli outcome: mag^2 * p * (1-p)
  varmd0_true <- mag0^2*mu_true0*(1-mu_true0)
  varmd1_true <- mag1^2*mu_true1*(1-mu_true1)
  varmd0 <- varmd0_true
  varmd1 <- varmd1_true
} else if (Figure1=='m') # Experiment (BudgetD_t): hist, t-dist
{
  mu_true0 <- 0
  mu_true1 <- 1
  dg_mode <- 5
  df <- 4 # degrees of freedom of the t distribution
  mag <- 5 # t draws are multiplied by sqrt(mag)
} else if (Figure1=='n') # Experiment (BudgetD_dec): hist, worsening update
{
  mu_true1_sq <- seq(from = 0, by = -1, length.out = T)
  dg_mode <- 2
}


#### MAIN ALGORITHM ####
for (plt in 1:plt_num){
  # Static schedule under our method: one budget and one overall tolerance
  # for the whole configuration, converted to a per-stage tolerance.
  if (bandit == 'ours' && is_dynamic == 0)
  {
    s <- s_plt[plt] # total risk budget
    deltao <- deltao_plt[plt] # risk tolerance
    delta <- 1 - (1 - deltao)^(1 / T) # probability of exceeding risk budget at each stage
  }

  add <- add_plt[plt] # plotting overlay code of this configuration

  # per-trial records, one row per Monte-Carlo trial
  mrec <- matrix(nrow = I, ncol = T) # treatment group size at each stage
  budgetsurplusrec <- matrix(nrow = I, ncol = T) # budget surplus after each stage
  budget_used <- numeric(I) # cumulative treatment effect at the end of each trial

  # accumulators over all trials of this configuration
  I_vio <- 0 # number of trials whose risk budget is exceeded/violated
  mTot <- 0 # total number of treated units
  obsTot <- 0 # total treatment effect
  accTot0 <- 0 # summed absolute estimation error, control
  accTot1 <- 0 # summed absolute estimation error, treatment

  Z_dist <- Normal(0, 1) # standard normal, used for the tolerance quantile
  
  for (i in 1:I)
  {
    
    if (is_print > 0)
    {
      print(sprintf("********Trial: %s********", i))
    }

    # Every trial starts from the priors.
    # (mupt1, varpt1): posterior mean/variance of the treatment parameter
    # before the current stage; (mupt0, varpt0): the same for control.
    mupt1 <- mu01
    varpt1 <- var01
    mupt0 <- mu00
    varpt0 <- var00

    # running totals over the stages completed so far
    m_sum <- 0 # number of treatment observations
    obs_sum1 <- 0 # sum of treatment observations
    obs_sum0 <- 0 # sum of control observations
    obs_sum1_sq <- 0 # sum of squared treatment observations
    obs_sum0_sq <- 0 # sum of squared control observations

    trteff <- 0 # cumulative treatment effect of the treatment group
    
    
    for (t in 1:T)
    {
      # Dynamic schedule under our method: budget and per-stage tolerance are
      # read from row plt, column t of the schedule matrices.
      if (is_dynamic==1 && bandit=='ours'){
        s <- s_plt[plt,t] # total risk budget
        delta <- deltao_plt[plt,t] # risk tolerance
      }
      
      N <- Narr[t] # users available at this stage
      if (is_print>0)
      {
        print(sprintf("###Stage: %s###", t))
        print(sprintf("Posterior mean,var control: %s,%s", mupt0, varpt0))
        print(sprintf("Posterior mean,var treatment: %s,%s", mupt1, varpt1))
        # The budget diagnostics need s, which is only assigned for
        # bandit=='ours'; printing them for 'thompson' would stop the run.
        if (bandit=='ours')
        {
          # stilde is first assigned later in the stage, so the value this
          # stage will use is computed here instead of reading the variable.
          stilde_now <- s-obs_sum1
          print(sprintf("Actual remaining blanace bf testing: %s", s-trteff))
          print(sprintf("obs_sum1,stilde, m_sum: %s, %s, %s", obs_sum1,stilde_now,m_sum))
          print(sprintf("Model remaining blanace bf testing: %s,%s", s-(obs_sum1-(mupt0-3*sqrt(varpt0))*m_sum),s-(obs_sum1-(mupt0+3*sqrt(varpt0))*m_sum)))
        }
      }
      
      # Choose m, the number of units put in treatment at this stage.
      if (bandit=='ours')
      {
        # standard normal quantile at the per-stage tolerance
        q <- quantile(Z_dist,delta)
        
        # budget minus the sum of treatment outcomes observed so far
        stilde <- s-obs_sum1
        
        # For a candidate m define
        #   mutilde(m)  = mupt1*m - mupt0*(m+m_sum)
        #   vartilde(m) = m^2*varpt1 + m*varmd1 + (m+m_sum)^2*varpt0 + (m+m_sum)*varmd0
        # A, B, C are the coefficients of the quadratic A*m^2+B*m+C=0 obtained
        # by squaring (stilde-mutilde(m))/sqrt(vartilde(m)) = q.
        A <- q^2*(varpt1+varpt0)-(mupt1-mupt0)^2
        B <- q^2*(varmd1+2*varpt0*m_sum+varmd0)+2*stilde*(mupt1-mupt0)+2*(mupt1-mupt0)*mupt0*m_sum
        C <- q^2*varpt0*m_sum^2+q^2*varmd0*m_sum-stilde^2-2*stilde*mupt0*m_sum-(mupt0*m_sum)^2
        
        # the same quantities evaluated at the cap m = floor(N/2)
        halfN <- floor(N/2)
        mutildeN <- mupt1*halfN-mupt0*(halfN+m_sum)
        vartildeN <- halfN^2*varpt1+halfN*varmd1+(halfN+m_sum)^2*varpt0+(halfN+m_sum)*varmd0
        
        # Fallback when the quadratic yields no usable root: take N/2 if the
        # condition holds at the cap, otherwise treat nobody. Kept as a
        # function so the test is only evaluated when a fallback is needed.
        fallback_m <- function()
        {
          if ((stilde-mutildeN)/sqrt(vartildeN)<q+0.0001) halfN else 0
        }
        
        disc <- B^2-4*A*C
        if (disc<0)
        {
          # no real root
          m <- fallback_m()
        }
        else
        {
          m1 <- (-B-sqrt(disc))/(2*A)
          m2 <- (-B+sqrt(disc))/(2*A)
          
          # squaring can introduce spurious roots: check each root against
          # the unsquared condition (with a small numerical slack)
          mutildem1 <- mupt1*m1-mupt0*(m1+m_sum)
          vartildem1 <- m1^2*varpt1+m1*varmd1+(m1+m_sum)^2*varpt0+(m1+m_sum)*varmd0
          ism1 <- (stilde-mutildem1)/sqrt(vartildem1)<q+0.0001
          
          mutildem2 <- mupt1*m2-mupt0*(m2+m_sum)
          vartildem2 <- m2^2*varpt1+m2*varmd1+(m2+m_sum)^2*varpt0+(m2+m_sum)*varmd0
          ism2 <- (stilde-mutildem2)/sqrt(vartildem2)<q+0.0001
          
          if (ism1 && ism2)
          {
            m <- max(m1,m2) # if both solution correct, take the larger one
          } else if (ism1 && !ism2)
          {
            m <- m1 # take the correct solution
          } else if (!ism1 && ism2)
          {
            m <- m2 # take the correct solution
          } else
          {
            # both roots fail the check
            m <- fallback_m()
          }
        }
        
        if (m<0)
        {
          # a negative group size is not usable either
          m <- fallback_m()
        }
        
        m <- floor(m)
        
        # Correct the bias of floor(): q1 and q2 are the model probabilities
        # that the outcome falls below stilde with m and with m+1 treated
        # units. When delta lies strictly between them, m+1 is chosen with
        # probability 1-p, where p*q1+(1-p)*q2 = delta.
        q1 <- cdf(Normal(mupt1*m-mupt0*(m+m_sum),sqrt(m^2*varpt1+m*varmd1+(m+m_sum)^2*varpt0+(m+m_sum)*varmd0)), stilde)
        ma <- m+1
        q2 <- cdf(Normal(mupt1*ma-mupt0*(ma+m_sum),sqrt(ma^2*varpt1+ma*varmd1+(ma+m_sum)^2*varpt0+(ma+m_sum)*varmd0)), stilde)
        if (abs(q1-q2)>0.0001 && min(q1,q2)<delta && max(q1,q2)>delta)
        {
          p <- (delta-q2)/(q1-q2)
          if (random(Bernoulli(p),1)==0)
          {
            m <- ma
          }
        }
        if (max(q1,q2)<delta)
        {
          # both probabilities are within tolerance: take the larger group
          m <- ma
        }
        
        # we will not treat more than half of available population
        if (m>halfN)
        {
          m <- halfN
        }
      } else if (bandit=='thompson')
      {
        # Floored so that an odd N cannot produce a fractional group size
        # (same cap as in the 'ours' branch).
        halfN <- floor(N/2)
        # probability that treatment beats control under the current
        # independent normal posteriors
        p_1 <- pnorm((mupt1-mupt0)/sqrt(varpt0+varpt1))
        p_0 <- 1-p_1
        
        # exponent c of this configuration; named tom_c so that base::c is
        # not masked by a numeric variable
        tom_c <- c_arr[plt]
        m <- floor(N*(p_1^tom_c/(p_1^tom_c+p_0^tom_c)))
        if (m>halfN)
        {
          m <- halfN
        }
        
      }
      
      if (is_print>0){
        print(sprintf("Treatment units: %s", m))
      }
      
      # Simulate the outcomes of the current stage according to dg_mode:
      #   res1   outcomes of the m treated units
      #   res0   outcomes of the N-m control units
      #   res10H control outcomes of the treated units (never observed by the
      #          method; only used to evaluate the realised treatment effect)
      if (dg_mode==1){
        ## this mode is constant mean and Gaussian outcome
        if (is_print>0)
        {
          print(sprintf("true mean, variance of control: %s, %s", mu_true0, varmd0_true))
          print(sprintf("true mean, variance of treatment: %s, %s", mu_true1, varmd1_true))
        }
        
        res1 <- random(Normal(mu_true1,sqrt(varmd1_true)),m) # treatment group obs
        res0 <- random(Normal(mu_true0,sqrt(varmd0_true)),N-m) # control group obs
        res10H <- random(Normal(mu_true0,sqrt(varmd0_true)),m) # treatment group control
      }
      else if (dg_mode==2){
        ## the mode where treatment mean changes, Gaussian outcome
        mu_true0 <- 0
        mu_true1 <- mu_true1_sq[t]
        
        if (is_print>0)
        {
          print(sprintf("true mean, variance of control: %s, %s", mu_true0, varmd0_true))
          print(sprintf("true mean, variance of treatment: %s, %s", mu_true1, varmd1_true))
        }
        
        res1 <- random(Normal(mu_true1,sqrt(varmd1_true)),m) # treatment group obs
        res0 <- random(Normal(mu_true0,sqrt(varmd0_true)),N-m) # control group obs
        res10H <- random(Normal(mu_true0,sqrt(varmd0_true)),m) # treatment group control
        
      } else if (dg_mode==3){
        ## binary outcome mode: Bernoulli draws scaled by mag1 / mag0
        res1 <- mag1*rbinom(m, 1, mu_true1) # treatment obs
        res0 <- mag0*rbinom(N-m, 1, mu_true0)  # control group obs
        res10H <- mag0*rbinom(m, 1, mu_true0) # treatment group control
      } else if (dg_mode==4){
        ## correlated Gaussian outcome mode: column 1 is the control outcome,
        ## column 2 the treatment outcome of the same unit
        X_all <- mvrnorm(N,mu_true, Cov_true)
        # seq_len(m) instead of 1:m: with m==0, 1:m is c(1,0) and would pick
        # row 1, giving the treatment arm one observation it does not have
        treated <- seq_len(m)
        res1 <- X_all[treated,2] # treatment obs
        res0 <- X_all[(m+1):N,1]# control group obs
        res10H <- X_all[treated,1] # treatment group control
      } else if (dg_mode==5)
      {
        ## fat tail, t-dist
        # NOTE(review): in this mode the treatment outcome is centred at
        # mu_true0 and the control outcome at mu_true1, i.e. the two names are
        # used the other way round than in the other modes, while the accuracy
        # summary still compares mupt0 with mu_true0. Left as is because the
        # configuration of this mode relies on it -- confirm intent.
        res1 <- mu_true0+sqrt(mag)*rt(m, df) # treatment obs
        res0 <- mu_true1+sqrt(mag)*rt(N-m, df)  # control group obs
        # same centre as res0 (previously a hard-coded 1, which is the value
        # mu_true1 has in the subplot that uses this mode)
        res10H <- mu_true1+sqrt(mag)*rt(m, df) # treatment group control
      } else if (dg_mode==6){
        ## LinkedIn, no outcome correlation: per-stage means and variances
        mu_true0 <- mu_true0_sq[t]
        mu_true1 <- mu_true1_sq[t]
        varmd0_true <- varmd0_sq[t]
        varmd1_true <- varmd1_sq[t]
        res1 <- random(Normal(mu_true1,sqrt(varmd1_true)),m) # treatment group obs
        res0 <- random(Normal(mu_true0,sqrt(varmd0_true)),N-m) # control group obs
        res10H <- random(Normal(mu_true0,sqrt(varmd0_true)),m) # treatment group control
      } else if (dg_mode==7){
        ## LinkedIn, correlated outcomes (off-diagonal covariance rho7)
        # NOTE(review): rho7 is not assigned anywhere in the visible part of
        # this script and must be defined before this mode is used.
        mu_true0 <- mu_true0_sq[t]
        mu_true1 <- mu_true1_sq[t]
        varmd0_true <- varmd0_sq[t]
        varmd1_true <- varmd1_sq[t]
        mu_true <- c(mu_true0, mu_true1)
        Cov_true <- matrix(c(varmd0_true,rho7,rho7,varmd1_true),nrow = 2, byrow = TRUE)
        X_all <- mvrnorm(N,mu_true, Cov_true)
        # seq_len(m) keeps the treatment arm empty when m==0 (see dg_mode 4)
        treated <- seq_len(m)
        res1 <- X_all[treated,2] # treatment obs
        res0 <- X_all[(m+1):N,1]# control group obs
        res10H <- X_all[treated,1] # treatment group control
      } 
      
      # update cumulative variables
      obs_sum1 <- obs_sum1+sum(res1)
      obs_sum0 <- obs_sum0+sum(res0)
      obs_sum1_sq <- obs_sum1_sq+sum(res1^2)
      obs_sum0_sq <- obs_sum0_sq+sum(res0^2)
      m_sum <- m_sum+m
      # realised effect: treated outcomes minus the control outcomes the same
      # units would have had
      trteff <- trteff+sum(res1)-sum(res10H)
      mrec[i,t] <- m
      budgetsurplusrec[i,t] <- trteff-Budget[plt]
      
      N_sum <- sum(Narr[1:t]) # units seen up to and including this stage
      n_ctrl_sum <- N_sum-m_sum # control observations so far
      
      if (emp_bayes!=0)
      {
        # Estimate the outcome variances from the data as E[x^2]-E[x]^2.
        # An estimate is only adopted when it is finite and positive: with no
        # observation in an arm it is NaN and with a single one it is 0, and
        # either value would turn the posterior below into NaN.
        # NOTE(review): varmd0/varmd1 are not reset at the start of a trial,
        # so the first stage of a trial uses the estimate left by the previous
        # trial -- confirm this carry-over is intended.
        varmd0_hat <- obs_sum0_sq/n_ctrl_sum-(obs_sum0/n_ctrl_sum)^2
        varmd1_hat <- obs_sum1_sq/m_sum-(obs_sum1/m_sum)^2
        if (is.finite(varmd0_hat) && varmd0_hat>0)
        {
          varmd0 <- varmd0_hat
        }
        if (is.finite(varmd1_hat) && varmd1_hat>0)
        {
          varmd1 <- varmd1_hat
        }
      }
      
      # Normal-mean conjugate update with outcome variance varmd0/varmd1:
      # posterior precision = prior precision + n/variance, and
      # posterior mean = (prior mean/prior var + sum of obs/variance)/precision.
      prec1 <- 1/var01+m_sum/varmd1
      mupt1 <- 1/prec1*(mu01/var01+obs_sum1/varmd1)
      varpt1 <- 1/prec1
      prec0 <- 1/var00+n_ctrl_sum/varmd0
      mupt0 <- 1/prec0*(mu00/var00+obs_sum0/varmd0)
      varpt0 <- 1/prec0
      
    }
    
    if (is_print>0){
      print(sprintf("###Trial Summary###"))
      print(sprintf("Post. mean var control: %s, %s", mupt0,varpt0))
      print(sprintf("Post. mean var treatment: %s, %s", mupt1,varpt1))
      
      print(sprintf("Total treatment obs: %s", m_sum))
      
      print(sprintf("Sum of treatment effect in treatment: %s",trteff))
      # Budget[plt] is the threshold the violation check below uses; unlike s
      # it is also defined when bandit=='thompson'.
      print(sprintf("Total budget: no less than %s",Budget[plt]))
      print('***************************')
    }
    
    # a trial violates its budget when the cumulative treatment effect ends
    # below the total budget of this configuration
    if (trteff<Budget[plt])
    {
      I_vio <- I_vio+1
    }
    
    budget_used[i] <- trteff
    
    mTot <- mTot+m_sum
    obsTot <- obsTot+trteff
    
    # absolute error of the final posterior means against the values held in
    # mu_true0/mu_true1 at the end of the trial
    # NOTE(review): for dg_mode 3 the simulated outcomes are scaled by
    # mag0/mag1 while mu_true0/mu_true1 are not, so these two summaries are
    # not on the same scale there -- confirm before relying on them.
    accTot0 <- accTot0+abs(mupt0-mu_true0)
    accTot1 <- accTot1+abs(mupt1-mu_true1)
  }
  
  # ---- summary of this configuration over all I trials ----
  print("*******Experiment Summary*********")
  print(sprintf("Prob. of violating budget: %s", I_vio / I))
  print(sprintf("Number units tested avg. per trial: %s", mTot / I))

  print(sprintf("Treatment effect sum avg. per trial: %s", obsTot / I))

  print(sprintf("Estimation accuracy treatment: %s", accTot1 / I))
  print(sprintf("Estimation accuracy control: %s", accTot0 / I))

  # column-wise (per-stage) quantile of a trials x stages record
  stage_quantile <- function(rec, pr) apply(rec, 2, quantile, probs = pr)

  print("Ramp size at each stage")
  mavg <- stage_quantile(mrec, 0.5)
  print(mavg)

  print("Budget spent at each stage")
  bavg <- stage_quantile(budgetsurplusrec, 0.5)
  print(bavg)

  # 75% and 25% quantiles of treatment group size and budget surplus
  m_updev <- stage_quantile(mrec, 0.75)
  m_downdev <- stage_quantile(mrec, 0.25)
  b_updev <- stage_quantile(budgetsurplusrec, 0.75)
  b_downdev <- stage_quantile(budgetsurplusrec, 0.25)

  # ---- line plots: plttype 1 (ramp schedule) or 2 (budget surplus) ----
  if (plttype == 1 || plttype == 2)
  {
    if (plttype == 1)
    {
      plt_x <- seq(1, T)
      plt_mid <- mavg
      plt_hi <- m_updev
      plt_lo <- m_downdev
      plt_xlab <- 'Stages'
      plt_ylab <- "Treatment Group Size"
      # subplots d and f draw their legend after the loop instead
      plt_legend <- Figure1 != 'd' && Figure1 != 'f'
    } else
    {
      plt_x <- seq(1, Tplt)
      plt_mid <- bavg[1:Tplt]
      plt_hi <- b_updev[1:Tplt]
      plt_lo <- b_downdev[1:Tplt]
      plt_xlab <- ''
      plt_ylab <- "Budget Surplus"
      plt_legend <- TRUE
    }

    is_middle <- add > 0 && add < plt_num - 1 # overlay that is not the last
    is_last <- add == plt_num - 1 # last overlay: the legend goes here

    if (add == 0)
    {
      # first configuration opens the plot
      par(mar = c(3, 4, 2, 1))
      errbar(plt_x, plt_mid, plt_hi, plt_lo, type = 'b', xlab = plt_xlab, ylab = plt_ylab, ylim = c(ymin, ymax))
    } else if (is_middle || is_last)
    {
      # later configurations are drawn on top, with quartile bars
      lines(plt_x, plt_mid, type = 'b', pch = 19, col = line_color[plt])
      arrows(plt_x, plt_lo, plt_x, plt_hi, length = 0.03, angle = 90, code = 3, col = line_color[plt])
      if (is_last && plt_legend)
      {
        legend(leg_pos, legend = legendtxt, col = line_color, pch = 19)
      }
    }
  }

  # ---- plttype 3: histogram of the end-of-trial budget use ----
  if (plttype == 3)
  {
    print('most budget spent:')
    print(min(budget_used))
    print('std budget spent:')
    print(sd(budget_used))
    budget_useddf <- as.data.frame(budget_used)
    hist_layer <- geom_histogram(binwidth = 20, color = "black", fill = "lightblue")
    # dashed red line at the risk budget s
    budget_line <- geom_vline(xintercept = s, color = "red", linetype = "dashed", size = 2)
    p <- ggplot(data = budget_useddf, aes(x = budget_used)) +
      hist_layer +
      budget_line +
      labs(x = "", y = "") +
      theme_bw()
    print(p)
  }
}

# After all configurations are drawn, subplots d and f get the ramp of the
# actual LinkedIn experiment as a thick grey line, plus their legend.
if (Figure1 == 'd' || Figure1 == 'f')
{
  linkedin_ramp <- c(107, 523, 1059, 1894, 5275, 5344)
  stage_idx <- seq(1, 6)

  if (Figure1 == 'd')
  {
    # redraw the last simulated configuration with a thicker line first
    lines(stage_idx, mavg[1:6], type = 'b', pch = 19, col = line_color[plt], lwd = 5)
    arrows(seq(1, T), m_downdev, seq(1, T), m_updev, length = 0.03, angle = 90, code = 3, col = line_color[plt], lwd = 5)
    lines(stage_idx, linkedin_ramp, type = 'b', lwd = 5, pch = 19, col = 'grey')
    legendtxt <- c('(-400, 0.01)', '(-400, 0.2)', '(-1500, 0.01)', '(-1500, 0.2)','ration-budget Linkedin', 'actual LinkedIn')
    legend(leg_pos, legend = legendtxt, col = line_color, pch = 19)
  } else
  {
    lines(stage_idx, linkedin_ramp, type = 'b', lwd = 5, pch = 19, col = 'grey')
    legend(leg_pos, legend = c('c=0.01', 'c=0.05', 'c=0.1', 'c=0.15','c=0.17', 'c=0.2', 'actual LinkedIn'), col = line_color, pch = 19)
  }
}
